{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parallel training with Thinc and Ray\n",
    "\n",
    "This notebook is based on one of [Ray's tutorials](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html) and shows how to use Thinc and Ray to implement parallel training. It includes implementations of both synchronous and asynchronous parameter server training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To let ray install its own version in Colab\n",
    "!pip uninstall -y pyarrow\n",
    "# You might need to restart the Colab runtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade \"thinc>=8.0.0a0\" \"ml_datasets>=0.2.0a0\" ray psutil setproctitle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with a simple model and [config file](https://thinc.ai/docs/usage-config). You can edit the `CONFIG` string in the notebook, or copy it out to a separate file and use `Config().from_disk` to load it from a path. The `[ray]` section contains the settings for Ray. (We're using a config for convenience, but you don't have to – you can also just hard-code the values.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import thinc\n",
    "from thinc.api import chain, Relu, Softmax\n",
    "\n",
    "@thinc.registry.layers(\"relu_relu_softmax.v1\")\n",
    "def make_relu_relu_softmax(hidden_width: int, dropout: float):\n",
    "    return chain(\n",
    "        Relu(hidden_width, dropout=dropout),\n",
    "        Relu(hidden_width, dropout=dropout),\n",
    "        Softmax(),\n",
    "    )\n",
    "\n",
    "CONFIG = \"\"\"\n",
    "[training]\n",
    "iterations = 200\n",
    "batch_size = 128\n",
    "\n",
    "[evaluation]\n",
    "batch_size = 256\n",
    "frequency = 10\n",
    "\n",
    "[model]\n",
    "@layers = \"relu_relu_softmax.v1\"\n",
    "hidden_width = 128\n",
    "dropout = 0.2\n",
    "\n",
    "[optimizer]\n",
    "@optimizers = \"Adam.v1\"\n",
    "\n",
    "[ray]\n",
    "num_workers = 2\n",
    "object_store_memory = 3000000000\n",
    "num_cpus = 2\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just like in the original Ray tutorial, we're using the MNIST data (via our [`ml-datasets`](https://github.com/explosion/ml-datasets) package) and are setting up two helper functions: \n",
    "\n",
    "1. `get_data_loader`: Return shuffled batches of a given batch size.\n",
    "2. `evaluate`: Evaluate a model on batches of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ml_datasets\n",
    "\n",
    "MNIST = ml_datasets.mnist()\n",
    "\n",
    "def get_data_loader(model, batch_size):\n",
    "    (train_X, train_Y), (dev_X, dev_Y) = MNIST\n",
    "    train_batches = model.ops.multibatch(batch_size, train_X, train_Y, shuffle=True)\n",
    "    dev_batches = model.ops.multibatch(batch_size, dev_X, dev_Y, shuffle=True)\n",
    "    return train_batches, dev_batches\n",
    "\n",
    "def evaluate(model, batch_size):\n",
    "    dev_X, dev_Y = MNIST[1]\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for X, Y in model.ops.multibatch(batch_size, dev_X, dev_Y):\n",
    "        Yh = model.predict(X)\n",
    "        correct += (Yh.argmax(axis=1) == Y.argmax(axis=1)).sum()\n",
    "        total += Yh.shape[0]\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Setting up Ray\n",
    "\n",
    "### Getters and setters for gradients and weights\n",
    "\n",
    "Using Thinc's `Model.walk` method, we can implement the following helper functions to get and set the weights and gradients of each node in a model's tree. The parameter server and the workers use these functions later on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def get_model_weights(model):\n",
    "    params = defaultdict(dict)\n",
    "    for node in model.walk():\n",
    "        for name in node.param_names:\n",
    "            if node.has_param(name):\n",
    "                params[node.id][name] = node.get_param(name)\n",
    "    return params\n",
    "\n",
    "def set_model_weights(model, params):\n",
    "    for node in model.walk():\n",
    "        for name, param in params[node.id].items():\n",
    "            node.set_param(name, param)\n",
    "\n",
    "def get_model_grads(model):\n",
    "    grads = defaultdict(dict)\n",
    "    for node in model.walk():\n",
    "        for name in node.grad_names:\n",
    "            grads[node.id][name] = node.get_grad(name)\n",
    "    return grads\n",
    "\n",
    "def set_model_grads(model, grads):\n",
    "    for node in model.walk():\n",
    "        for name, grad in grads[node.id].items():\n",
    "            node.set_grad(name, grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Parameter Server\n",
    "\n",
    "> The parameter server will hold a copy of the model. During training, it will:\n",
    ">\n",
    "> 1. Receive gradients and apply them to its model.\n",
    "> 2. Send the updated model back to the workers.\n",
    ">\n",
    "> The `@ray.remote` decorator defines a remote process. It wraps the `ParameterServer` class and allows users to instantiate it as a remote actor. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#defining-the-parameter-server))\n",
    "\n",
    "Here, the `ParameterServer` is initialized with a model and an optimizer. It has one method to apply the gradients received from the workers and one method to get the weights of the current model, both built on the helper functions defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "@ray.remote\n",
    "class ParameterServer:\n",
    "    def __init__(self, model, optimizer):\n",
    "        self.model = model\n",
    "        self.optimizer = optimizer\n",
    "\n",
    "    def apply_gradients(self, *worker_grads):\n",
    "        summed_gradients = defaultdict(dict)\n",
    "        for grads in worker_grads:\n",
    "            for node_id, node_grads in grads.items():\n",
    "                for name, grad in node_grads.items():\n",
    "                    if name in summed_gradients[node_id]:\n",
    "                        summed_gradients[node_id][name] += grad\n",
    "                    else:\n",
    "                        summed_gradients[node_id][name] = grad.copy()\n",
    "        set_model_grads(self.model, summed_gradients)\n",
    "        self.model.finish_update(self.optimizer)\n",
    "        return get_model_weights(self.model)\n",
    "\n",
    "    def get_weights(self):\n",
    "        return get_model_weights(self.model)"
   ]
  },
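  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of the summing step in `apply_gradients`, the sketch below uses NumPy arrays as stand-ins for the gradients that two workers might return. The node id `1`, the parameter name `W` and the names `sum_worker_grads`, `worker_a` and `worker_b` are made up for this example and are not part of Thinc or Ray. It shows why the first gradient for each parameter is copied: `+=` adds in place, so without the copy the first worker's array would be overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def sum_worker_grads(*worker_grads):\n",
    "    # Same accumulation logic as ParameterServer.apply_gradients above.\n",
    "    summed = defaultdict(dict)\n",
    "    for grads in worker_grads:\n",
    "        for node_id, node_grads in grads.items():\n",
    "            for name, grad in node_grads.items():\n",
    "                if name in summed[node_id]:\n",
    "                    summed[node_id][name] += grad  # in-place addition\n",
    "                else:\n",
    "                    summed[node_id][name] = grad.copy()  # don't alias the input\n",
    "    return summed\n",
    "\n",
    "# Hypothetical gradients from two workers for one node with one parameter.\n",
    "worker_a = {1: {\"W\": np.array([1.0, 2.0])}}\n",
    "worker_b = {1: {\"W\": np.array([0.5, 0.5])}}\n",
    "summed = sum_worker_grads(worker_a, worker_b)\n",
    "print(summed[1][\"W\"])    # element-wise sum of both gradients\n",
    "print(worker_a[1][\"W\"])  # unchanged, because the first gradient was copied"
   ]
  },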
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Worker\n",
    "\n",
    "> The worker will also hold a copy of the model. During training it will continuously evaluate data and send gradients to the parameter server. The worker will synchronize its model with the Parameter Server model weights. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#defining-the-worker))\n",
    "\n",
    "To compute the gradients during training, we can call the model on a batch of data (and set `is_train=True`). This returns the predictions and a `backprop` callback. Calling `backprop` with the gradient of the loss with respect to the predictions computes the gradients of the model's parameters, which the worker then returns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import fix_random_seed\n",
    "\n",
    "@ray.remote\n",
    "class DataWorker:\n",
    "    def __init__(self, model, batch_size=128, seed=0):\n",
    "        self.model = model\n",
    "        fix_random_seed(seed)\n",
    "        self.data_iterator = iter(get_data_loader(model, batch_size)[0])\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def compute_gradients(self, weights):\n",
    "        set_model_weights(self.model, weights)\n",
    "        try:\n",
    "            data, target = next(self.data_iterator)\n",
    "        except StopIteration:  # When the epoch ends, start a new epoch.\n",
    "            self.data_iterator = iter(get_data_loader(self.model, self.batch_size)[0])\n",
    "            data, target = next(self.data_iterator)\n",
    "        guesses, backprop = self.model(data, is_train=True)\n",
    "        backprop((guesses - target) / target.shape[0])\n",
    "        return get_model_grads(self.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up the model\n",
    "\n",
    "Using the `CONFIG` defined above, we can load the settings and set up the model and optimizer. Thinc's `registry.resolve` will parse the config, resolve all references to registered functions and return a dict of the resolved objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import registry, Config\n",
    "C = registry.resolve(Config().from_str(CONFIG))\n",
    "C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We didn't specify all the dimensions in the model, so we need to pass in a batch of data to finish initialization. This lets Thinc infer the missing shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = C[\"optimizer\"]\n",
    "model = C[\"model\"]\n",
    "\n",
    "(train_X, train_Y), (dev_X, dev_Y) = MNIST\n",
    "model.initialize(X=train_X[:5], Y=train_Y[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Training\n",
    "\n",
    "### Synchronous Parameter Server training\n",
    "\n",
    "We can now create a synchronous parameter server training scheme:\n",
    "\n",
    "1. Call `ray.init` with the settings defined in the config.\n",
    "2. Instantiate a process for the `ParameterServer`.\n",
    "3. Create multiple workers (`num_workers`, as defined in the config).\n",
    "\n",
    "\n",
    "Though this is not specifically mentioned in the Ray tutorial, we're setting a different random seed for each worker here.\n",
    "Otherwise the workers may iterate over the batches in the same order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(\n",
    "    ignore_reinit_error=True,\n",
    "    object_store_memory=C[\"ray\"][\"object_store_memory\"],\n",
    "    num_cpus=C[\"ray\"][\"num_cpus\"],\n",
    ")\n",
    "ps = ParameterServer.remote(model, optimizer)\n",
    "workers = []\n",
    "for i in range(C[\"ray\"][\"num_workers\"]):\n",
    "    worker = DataWorker.remote(model, batch_size=C[\"training\"][\"batch_size\"], seed=i)\n",
    "    workers.append(worker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On each iteration, every worker computes its gradients. Once all gradients are available, `ParameterServer.apply_gradients` is called to calculate the update. The `frequency` setting in the `evaluation` config specifies how often to evaluate – for instance, a frequency of `10` means we're only evaluating on every 10th iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_weights = ps.get_weights.remote()\n",
    "for i in range(C[\"training\"][\"iterations\"]):\n",
    "    gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]\n",
    "    current_weights = ps.apply_gradients.remote(*gradients)\n",
    "    if i % C[\"evaluation\"][\"frequency\"] == 0:\n",
    "        set_model_weights(model, ray.get(current_weights))\n",
    "        accuracy = evaluate(model, C[\"evaluation\"][\"batch_size\"])\n",
    "        print(f\"{i} \\taccuracy: {accuracy:.3f}\")\n",
    "# Evaluate once more so the final score reflects the last update.\n",
    "set_model_weights(model, ray.get(current_weights))\n",
    "accuracy = evaluate(model, C[\"evaluation\"][\"batch_size\"])\n",
    "print(f\"Final \\taccuracy: {accuracy:.3f}\")\n",
    "ray.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Asynchronous Parameter Server training\n",
    "\n",
    "> Here, workers will asynchronously compute the gradients given its current weights and send these gradients to the parameter server as soon as they are ready. When the Parameter server finishes applying the new gradient, the server will send back a copy of the current weights to the worker. The worker will then update the weights and repeat. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#asynchronous-parameter-server-training))\n",
    "\n",
    "The setup looks the same and we can reuse the config. Make sure to call `ray.shutdown()` to clean up resources and processes before calling `ray.init` again. Note that the local `model` now holds the weights from the synchronous run (they were set during evaluation), so the asynchronous run continues from those weights unless you re-create and re-initialize the model first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(\n",
    "    ignore_reinit_error=True,\n",
    "    object_store_memory=C[\"ray\"][\"object_store_memory\"],\n",
    "    num_cpus=C[\"ray\"][\"num_cpus\"],\n",
    ")\n",
    "ps = ParameterServer.remote(model, optimizer)\n",
    "workers = []\n",
    "for i in range(C[\"ray\"][\"num_workers\"]):\n",
    "    worker = DataWorker.remote(model, batch_size=C[\"training\"][\"batch_size\"], seed=i)\n",
    "    workers.append(worker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_weights = ps.get_weights.remote()\n",
    "gradients = {}\n",
    "for worker in workers:\n",
    "    gradients[worker.compute_gradients.remote(current_weights)] = worker\n",
    "\n",
    "for i in range(C[\"training\"][\"iterations\"] * C[\"ray\"][\"num_workers\"]):\n",
    "    ready_gradient_list, _ = ray.wait(list(gradients))\n",
    "    ready_gradient_id = ready_gradient_list[0]\n",
    "    worker = gradients.pop(ready_gradient_id)\n",
    "    current_weights = ps.apply_gradients.remote(ready_gradient_id)\n",
    "    gradients[worker.compute_gradients.remote(current_weights)] = worker\n",
    "    if i % C[\"evaluation\"][\"frequency\"] == 0:\n",
    "        set_model_weights(model, ray.get(current_weights))\n",
    "        accuracy = evaluate(model, C[\"evaluation\"][\"batch_size\"])\n",
    "        print(f\"{i} \\taccuracy: {accuracy:.3f}\")\n",
    "# Evaluate once more so the final score reflects the last update.\n",
    "set_model_weights(model, ray.get(current_weights))\n",
    "accuracy = evaluate(model, C[\"evaluation\"][\"batch_size\"])\n",
    "print(f\"Final \\taccuracy: {accuracy:.3f}\")\n",
    "ray.shutdown()"
   ]
  },
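  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare the two training loops, the following sketch counts how many parameter updates, gradients and in-loop evaluations each scheme goes through for the values in `CONFIG`. It's plain Python and doesn't use Ray or Thinc; the function names `sync_schedule` and `async_schedule` are made up for this illustration. The synchronous loop applies one update per iteration and sums one gradient from every worker, while the asynchronous loop applies a single gradient per step and therefore runs `iterations * num_workers` steps to consume the same number of gradients. As a result, the step numbers printed by the two loops are not directly comparable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sync_schedule(iterations, num_workers, frequency):\n",
    "    updates = iterations                       # one apply_gradients call per iteration\n",
    "    gradients = iterations * num_workers       # every worker contributes to each call\n",
    "    evals = len(range(0, updates, frequency))  # in-loop evaluations at i = 0, frequency, ...\n",
    "    return updates, gradients, evals\n",
    "\n",
    "def async_schedule(iterations, num_workers, frequency):\n",
    "    updates = iterations * num_workers         # one apply_gradients call per ready gradient\n",
    "    gradients = updates                        # each call applies a single gradient\n",
    "    evals = len(range(0, updates, frequency))\n",
    "    return updates, gradients, evals\n",
    "\n",
    "# Values taken from the CONFIG string above.\n",
    "sync_counts = sync_schedule(iterations=200, num_workers=2, frequency=10)\n",
    "async_counts = async_schedule(iterations=200, num_workers=2, frequency=10)\n",
    "print(\"sync  (updates, gradients, evaluations):\", sync_counts)\n",
    "print(\"async (updates, gradients, evaluations):\", async_counts)"
   ]
  },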
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Links & Resources\n",
    "\n",
    "- [Ray documentation](https://ray.readthedocs.io/en/latest/index.html)\n",
    "- [Training models](https://thinc.ai/docs/usage-training) (Thinc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 64-bit ('env3.7': venv)",
   "language": "python",
   "name": "python37564bitenv37venv23914e112e9949feb5f1b5cfd33771f7"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
